Implementing GPT2 in JAX for fun 🦀🦀🦀

```python
import typing as tp

import jax
import jax.numpy as jnp

import equinox as eqx
import equinox.nn as nn
```
1 GPT2 for JAX 🚀
Explore the full project on the GitHub repository.
1.1 Context ✍️
This project involves rewriting XTTS in JAX to better understand its architecture and functionality. Originally developed by the now-defunct Coqui company, XTTS is a Text-to-Speech model. We’ll recreate its generative component using a GPT2 architecture, a decoder-only transformer, based on (Radford et al. 2019). The implementation closely follows this tutorial.
1.2 GPT2 in Text-to-Speech
1.2.1 What are we building?
Our goal is to generate sequences of tokens for audio synthesis. Specifically, we aim to produce “audio tokens”: small units of audio discovered using a VQ-VAE. By learning to map text tokens to audio tokens, the model becomes multi-modal.
The final output sequences represent speech, which we convert into audio using HiFiGAN. Additionally, we enhance speech expressiveness (e.g., tone, speed) by feeding in 1024-dimensional vectors that represent the target speaker’s paralinguistic features.
1.2.2 Under the Hood
Masked Attention
Masked attention is the core mechanism for learning relationships between tokens. It determines which tokens influence others by projecting them into smaller dimensions and computing relationships. Masking ensures the model focuses only on prior tokens, preventing it from “seeing” future ones.
Studies classify attention patterns into:
1. Semantic: Tokens linked by meaning.
2. Linguistic: Tokens connected by grammar (e.g., verbs and nouns).
3. Rare Tokens: Infrequent but critical tokens.
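The masking trick itself fits in a few lines. Below is a toy 4-token example with made-up scores (plain NumPy, independent of the model code later on): positions above the diagonal are set to negative infinity, so they receive zero weight after the softmax.

```python
import numpy as np

T = 4
scores = np.ones((T, T))            # pretend attention scores, all equal
mask = np.tril(np.ones((T, T)))     # 1 where attending is allowed (past and self)
masked = np.where(mask == 0, -np.inf, scores)
# softmax per row: -inf entries become exactly 0
weights = np.exp(masked) / np.exp(masked).sum(axis=-1, keepdims=True)
# row i spreads its attention only over tokens 0..i
print(weights)
```

The first token can only attend to itself, so its row is `[1, 0, 0, 0]`; the last row spreads weight evenly over all four tokens.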
Feedforward Layers
Feedforward layers mix outputs, add non-linearity via activation functions, and stack layers for hierarchical abstractions. The final output approximates a one-hot encoding in the token vocabulary, enabling token selection for sequential generation.
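That last step can be illustrated with made-up logits over a five-token vocabulary (toy numbers, not real model output): the softmax concentrates mass on the largest logit, approximating a one-hot vector, and greedy selection takes its index.

```python
import numpy as np

logits = np.array([-1.0, 0.5, 3.0, 0.0, -2.0])  # raw scores over a 5-token vocab
probs = np.exp(logits - logits.max())
probs /= probs.sum()                # softmax: mass concentrates on index 2
next_token = int(np.argmax(probs))  # greedy selection picks token id 2
```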
1.3 Goal 🎯
Implement a GPT2 architecture using Equinox and train it on TinyStories.
2 Model
We have a few things to implement from the ground up: the custom activation function, the feedforward layer, and the masked attention. We then package these up in a layer that we can stack, and finally wrap all these stacks into a GPT2!
We can start by importing our favorite libraries 🥰
2.1 Configuration file
Because of the size of our model, we’re going to be passing down lots of arguments. To avoid a long, unreadable list of parameters, we define a dataclass that lets us simply pass a config down to the model.
Feel free to experiment with various settings!
```python
from dataclasses import dataclass


@dataclass
class GPTConfig:
    block_size: int = 100
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 16
    n_head: int = 12
    n_embd: int = 1024
    dropout: float = 0.1
    bias: bool = True
```

2.2 SwiGLU Activation Function
We start by implementing the SwiGLU activation function, introduced in (Shazeer 2020), a powerful variant of GLU.
2.2.1 Why SwiGLU?
SwiGLU dynamically adjusts its activation based on the input. Think of it like a railway switch—redirecting the “activation path” when the input carries different information. This gives the network greater flexibility and control, leading to better performance.
For more details, see this explanation by Boudefel.
Below is a visualization of the Swish function, \(x \times \text{sigmoid}(x)\), which plays a role in SwiGLU:
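In lieu of the plot, here are a few values computed directly (a plain-NumPy sketch of the same function, to keep it self-contained):

```python
import numpy as np

def swish(x):
    """Swish: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

xs = np.array([-10.0, -1.0, 0.0, 1.0, 10.0])
ys = swish(xs)
# ~0 for very negative inputs, exactly 0 at 0, ~x for large positive inputs,
# with a small negative dip in between
print(ys)
```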
```python
class SwiGLU(eqx.Module):
    W: nn.Linear
    V: nn.Linear
    b: jax.Array
    c: jax.Array

    def __init__(self, input_dim, output_dim, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.W = nn.Linear(input_dim, output_dim, key=key1)
        self.V = nn.Linear(input_dim, output_dim, key=key2)
        self.b = jax.random.normal(key3, (output_dim,))
        self.c = jax.random.normal(key4, (output_dim,))

    @eqx.filter_jit
    def __call__(self, x):
        # Swish(xW + b) * (xV + c), as defined in (Shazeer 2020)
        return jax.nn.swish(self.W(x) + self.b) * (self.V(x) + self.c)
```
```python
key = jax.random.PRNGKey(69)
mod = SwiGLU(10, 4, key)
x = jnp.ones(10)
print(mod(x).shape)  # (4,)
```

2.3 MLP
We can now move on to the multilayer perceptron, which we mentioned earlier as the feedforward part of our network. Because the model is big and we want to make sure it doesn’t just “memorize” things, we include dropout, which discourages the model from relying on any single neuron for information.
✨ You’ll also notice that since our SwiGLU contains two linear layers, each MLP we use actually has 4 linear layers!
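Before wiring dropout into the MLP, here is a minimal sketch of what it does, in plain NumPy (this is inverted dropout, the variant most libraries implement): surviving activations are rescaled by 1/(1-p) so the expected value is unchanged.

```python
import numpy as np

def dropout(x, p, rng):
    keep = rng.random(x.shape) >= p            # drop each unit with probability p
    return np.where(keep, x / (1.0 - p), 0.0)  # rescale survivors to preserve the mean

rng = np.random.default_rng(0)
x = np.ones(100_000)
y = dropout(x, p=0.1, rng=rng)
# roughly 10% of entries are zeroed, and the mean stays close to 1.0
```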
```python
class MLP(eqx.Module):
    ff1: nn.Linear
    ff2: nn.Linear
    act: SwiGLU
    drop: nn.Dropout

    def __init__(self, config, key):
        key1, key2, key3 = jax.random.split(key, 3)
        self.ff1 = nn.Linear(
            config.n_embd, 4 * config.n_embd, use_bias=config.bias, key=key1
        )
        self.act = SwiGLU(4 * config.n_embd, 4 * config.n_embd, key=key2)
        self.ff2 = nn.Linear(
            4 * config.n_embd, config.n_embd, use_bias=config.bias, key=key3
        )
        # deterministic=True disables dropout; set it to False (and pass a key)
        # to actually drop activations during training
        self.drop = nn.Dropout(config.dropout, deterministic=True)

    @eqx.filter_jit
    def __call__(self, x):
        y = self.ff1(x)
        y = self.act(y)
        y = self.ff2(y)
        return self.drop(y)
```

2.4 Masked attention
Moving on to one of the more complicated aspects of the model. In the end, attention simply learns which tokens are most relevant to which other tokens. There are plenty of fantastic tutorials out there for better understanding the underlying concept, notably: Transformers explained visually.
Our queries, keys and values are all linear projections of the tokens; the scores we compute between them capture their relationships:
```python
import math


class CausalSelfAttention(eqx.Module):
    attnk: nn.Linear
    attnq: nn.Linear
    attnv: nn.Linear
    proj: nn.Linear
    resid_dropout: nn.Dropout
    attn_dropout: nn.Dropout
    mask: jax.Array

    def __init__(self, config, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.attnk = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key1
        )
        self.attnv = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key2
        )
        self.attnq = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key3
        )
        self.attn_dropout = nn.Dropout(config.dropout, deterministic=True)
        self.resid_dropout = nn.Dropout(config.dropout, deterministic=True)
        self.proj = nn.Linear(
            config.n_embd, config.n_embd, use_bias=config.bias, key=key4
        )
        # Lower-triangular mask: position i may only attend to positions <= i.
        self.mask = jnp.tril(jnp.ones((config.block_size, config.block_size)))

    # Could play around with different attention score calculations.
    # Note: this is single-head attention; config.n_head is not used here.
    @eqx.filter_jit
    def __call__(self, x):
        # x is a sequence of embeddings; it should self-attend.
        T, C = x.shape  # Sequence length and embedding dim.
        q = jax.vmap(self.attnq)(x)
        k = jax.vmap(self.attnk)(x)
        v = jax.vmap(self.attnv)(x)
        # Scaled dot-product attention scores.
        att = jnp.matmul(q, jnp.transpose(k)) / math.sqrt(jnp.shape(k)[-1])
        # Masked positions get -inf so they vanish after the softmax.
        att = jnp.where(
            jnp.equal(jax.lax.stop_gradient(self.mask[:T, :T]), 0),
            float("-inf"),
            att,
        )
        att = jax.nn.softmax(att, axis=-1)
        att = self.attn_dropout(att)
        y = jnp.matmul(att, v)
        y = jax.vmap(self.proj)(y)
        y = self.resid_dropout(y)
        return y
```

Small check…
```python
import optax

config = GPTConfig()
key = jax.random.PRNGKey(69)
attn = CausalSelfAttention(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(eqx.filter(attn, eqx.is_array))
x = jax.random.normal(jax.random.key(2), (30, config.n_embd))


@eqx.filter_jit
def calculate_loss(model, x, y):
    output = model(x)
    return jnp.mean(jnp.abs(y - output))


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(attn, opt_state, x, x))
print(attn(jax.random.normal(key, (100, config.n_embd))).shape)
```

2.5 Block
Ok! Now that we have the component parts of what we call a “block”, we can assemble them. The block will then be stacked to get as many layers of abstraction as we wish. In our case we will stack it 16 times, as per the GPTConfig we defined.
```python
class Block(eqx.Module):
    norm: nn.LayerNorm
    attn: CausalSelfAttention
    mlp: MLP

    def __init__(self, config, key):
        key1, key2 = jax.random.split(key, 2)
        # Note: the same LayerNorm weights are reused before the attention and
        # before the MLP; GPT2 proper uses two separate norms (ln_1 and ln_2).
        self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.attn = CausalSelfAttention(config, key=key1)
        self.mlp = MLP(config, key=key2)

    @eqx.filter_jit
    def __call__(self, x):
        y = jax.vmap(self.norm)(x)
        # Can't vmap the attention over tokens: its whole point is to
        # exchange information between tokens.
        y = self.attn(y)
        x = y + x  # residual connection
        y = jax.vmap(self.norm)(x)
        y = jax.vmap(self.mlp)(y)
        x = y + x  # residual connection
        return x
```

We can compare this with the reference implementation.
```python
import optax

config = GPTConfig()
key = jax.random.PRNGKey(69)
block = Block(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(eqx.filter(block, eqx.is_array))
x = jax.random.normal(jax.random.key(2), (30, config.n_embd))


@eqx.filter_jit
def calculate_loss(model, x, y):
    output = model(x)
    return jnp.mean(jnp.abs(y - output))


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(block, opt_state, x, x))
```

We can finally add the embeddings to our model. These are the maps that send token ids to the dimension the model works with, i.e. 1024 dims.
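Conceptually, an embedding is just a learned lookup table with one row per id. A toy version in NumPy (random table, made-up tiny sizes rather than the real config):

```python
import numpy as np

vocab_size, n_embd = 8, 4                    # toy sizes, not the real config
rng = np.random.default_rng(0)
wte = rng.normal(size=(vocab_size, n_embd))  # one row per token id
token_ids = np.array([3, 1, 3])
emb = wte[token_ids]                         # shape (3, 4): one vector per token
# the same token id always maps to the same vector
```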
```python
class GPT(eqx.Module):
    wte: nn.Embedding  # Token embeddings
    wpe: nn.Embedding  # Positional embeddings
    drop: nn.Dropout
    layers: list
    norm: nn.LayerNorm
    lm_head: nn.Linear

    def __init__(self, config, key):
        key1, key2, key3, key4 = jax.random.split(key, 4)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd, key=key1)
        self.wpe = nn.Embedding(config.block_size, config.n_embd, key=key2)
        self.drop = nn.Dropout(config.dropout, deterministic=True)
        self.layers = [
            Block(config, key) for key in jax.random.split(key3, config.n_layer)
        ]
        self.norm = nn.LayerNorm(config.n_embd, use_bias=config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, key=key4)

    @eqx.filter_jit
    def __call__(self, token_ids):
        (t,) = token_ids.shape
        # Could use better positional embeddings with cos and sin.
        pos = jnp.arange(0, t)
        tok_emb = jax.vmap(self.wte)(token_ids)
        pos_emb = jax.vmap(self.wpe)(pos)
        # Dropout at the first layer? Seems a bit aggressive...
        x = self.drop(tok_emb + pos_emb)
        for block in self.layers:
            x = block(x)
        x = jax.vmap(self.norm)(x)
        # Raw logits; the softmax lives inside the loss function.
        logits = jax.vmap(self.lm_head)(x)
        return logits
```
```python
import optax

config = GPTConfig()
key = jax.random.PRNGKey(69)
model = GPT(config, key)
optimizer = optax.adam(1e-5)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
# Sequence length must not exceed config.block_size (the positional table size).
x = jnp.ones((30, config.block_size), dtype=jnp.int32)


# @eqx.filter_jit
def calculate_loss(model, x, y):
    output = jax.vmap(model)(x)
    return jnp.mean(
        jax.vmap(optax.softmax_cross_entropy_with_integer_labels)(output, y)
    )


def make_step(model, opt_state, x, y):
    loss_step, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss_step


print(make_step(model, opt_state, x, x))
```

3 Training
We can now move on to training the model! We’ll use the TinyStories (Eldan and Li 2023) dataset. Tiktoken maps the sentences to sequences of tokens the model can understand. Below is the code to download the data, transform it into a binary file, and then feed it to our training loop through a dataloader.
Example TinyStories:
Once upon a time, in a small house, there lived a boy named Tim. Tim was a selfish boy. He did not like to share his things with others. One day, his mom brought home a big bag of vegetables. She wanted to...
```python
# Saves the TinyStories dataset to a binary file for training. The following was helpful:
# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py
import os

from tqdm import tqdm
import numpy as np
import tiktoken
from datasets import load_dataset  # huggingface datasets

# number of workers in the .map() call
# a good number to use is ~ number of cpu cores // 2
num_proc = 16

dataset = load_dataset("roneneldan/TinyStories")

# we now want to tokenize the dataset. first define the encoding function (gpt2 bpe)
enc = tiktoken.get_encoding("gpt2")


def process(example):
    ids = enc.encode_ordinary(
        example["text"]
    )  # encode_ordinary ignores any special tokens
    ids.append(enc.eot_token)  # add the end-of-text token, e.g. 50256 for gpt2 bpe
    # note: I think eot should be prepended not appended... hmm. it's called "eot" though...
    out = {"ids": ids, "len": len(ids)}
    return out


# tokenize the dataset
tokenized = dataset.map(
    process,
    remove_columns=["text"],
    desc="tokenizing the splits",
    num_proc=num_proc,
)

# concatenate all the ids in each dataset into one large file we can use for training
for split, dset in tokenized.items():
    arr_len = np.sum(dset["len"])
    filename = os.path.join("dataset", f"{split}.bin")
    dtype = np.uint16  # (can do since enc.max_token_value == 50256 is < 2**16)
    arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
    total_batches = 1024

    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
        # Batch together samples for faster writes
        batch = dset.shard(
            num_shards=total_batches, index=batch_idx, contiguous=True
        ).with_format("numpy")
        arr_batch = np.concatenate(batch["ids"])
        # Write into the memory-mapped file
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)
    arr.flush()
```

We can now load the data from its binary representation into inputs and targets. Since we want the GPT to learn to predict the next token, we simply shift the input by 1!
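The shift-by-one is easiest to see on a toy token stream (made-up ids, not real data): each input position's target is simply the token that follows it.

```python
import numpy as np

block_size = 4
data = np.array([10, 11, 12, 13, 14, 15], dtype=np.uint16)  # toy token stream
i = 1  # start of the sampled window
x = data[i : i + block_size]           # inputs:  tokens at positions t
y = data[i + 1 : i + 1 + block_size]   # targets: tokens at positions t + 1
# at every position, the target is the token that follows the input
print(x, y)
```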
```python
import os

import jax.numpy as jnp
import numpy

data_dir = "dataset"
config = GPTConfig()


def get_batch(split: str):
    # We recreate numpy.memmap every batch to avoid a memory leak, as per
    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
    if split == "train":
        data = numpy.memmap(
            os.path.join(data_dir, "train.bin"), dtype=numpy.uint16, mode="r"
        )
    else:
        data = numpy.memmap(
            os.path.join(data_dir, "validation.bin"), dtype=numpy.uint16, mode="r"
        )
    ix = numpy.random.randint(len(data) - config.block_size, size=(32,))
    x = jnp.stack([jnp.array(data[i : i + config.block_size]) for i in ix])
    # Targets are the inputs shifted by one position.
    y = jnp.stack([jnp.array(data[i + 1 : i + 1 + config.block_size]) for i in ix])
    return x, y
```

We can now define our loss function. Our goal is to push the model to output something close to [0, 0, 0, …, 1, …, 0, 0], where the 1 is placed at the \(n\)th index. This index would ideally correspond to the word we’re attempting to match. optax, JAX’s ML optimisation library, conveniently has a function for this.
```python
import optax

learning_rate = 1e-4
warmup_iters = 3
init_from = "scratch"
lr_decay_iters = 20
iter_num = 0
min_lr = 1e-6

lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=learning_rate,
    warmup_steps=warmup_iters if init_from == "scratch" else 0,
    decay_steps=lr_decay_iters - iter_num,
    end_value=min_lr,
)
# Note: to actually apply the schedule, pass learning_rate=lr_scheduler below;
# as written, the optimizer uses a constant learning rate.
optimizer = optax.inject_hyperparams(optax.adamw)(learning_rate=learning_rate)


@eqx.filter_jit
def calculate_loss(model, x, y):
    output = jax.vmap(model)(x)
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(output, y))


def make_step(model, optimizer_state, x, y):
    loss, grads = eqx.filter_value_and_grad(calculate_loss)(model, x, y)
    updates, optimizer_state = optimizer.update(
        grads, optimizer_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, optimizer_state, loss
```

We can now move on to initializing our model and training it! We log the progress on wandb to see the loss curve.
I train two separate models; below are the parameters I changed:
| Experiment ID | Learning Rate | Block size | Dropout | N_embed | Layers |
|---|---|---|---|---|---|
| Run1 | 1e-4 | 100 | 0.1 | 1024 | 16 |
| Run2 | 1e-4 | 100 | 0.0 | 512 | 12 |
```python
import tiktoken
import wandb

key = jax.random.PRNGKey(69)
gptconf = GPTConfig()
model = GPT(gptconf, key)
wandb.init(project="gpt-training", config=gptconf.__dict__)
optimizer_state = optimizer.init(eqx.filter(model, eqx.is_array))
num_iterations = 1000
enc = tiktoken.get_encoding("gpt2")

for local_iter_num in range(num_iterations):
    x, y = get_batch("train")
    # Perform a single training step
    model, optimizer_state, loss = make_step(model, optimizer_state, x, y)
    wandb.log({"loss": loss, "iteration": local_iter_num})
    print(f"loss {loss}, iteration {local_iter_num}")
```

Different experiments show varying results.
After training, we can save the model to a local directory for later inference. I quickly check whether the model produces gibberish or not:

```python
eqx.tree_serialise_leaves("gpt2.eqx", model)
```

With the code below, we get a linguistically acceptable output:
Once upon a time, there was a little girl named Lily. She was so happy.
The little girl was so happy to the park
```python
enc = tiktoken.get_encoding("gpt2")

start = "Once upon"
x = jnp.array([enc.encode(start)])
# Greedy decoding: repeatedly append the most likely next token, stopping at
# end-of-text or when we run out of positional embeddings.
while x[0, -1] != enc.eot_token and x.shape[1] < gptconf.block_size:
    logits = jax.vmap(model)(x)
    x = jnp.concatenate(
        [x, jnp.array([[jnp.argmax(logits[0, -1])]])], axis=-1
    )
print(enc.decode(jnp.squeeze(x, axis=0)))
```

This concludes the training of a rudimentary GPT2 model. To go further, we could implement key-value caching, which reduces redundant computation. We could also implement beam search to find more optimal sequences of words, and use top-k and top-p sampling along with temperature to add diversity. This is left as an exercise to the reader.
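As a hint for that exercise, here is a minimal sketch of one top-k + temperature sampling step in plain NumPy (toy logits and a hypothetical `sample_top_k` helper, not the trained model):

```python
import numpy as np

def sample_top_k(logits, k, temperature, rng):
    """Sample one token id among the k most likely, after temperature scaling."""
    scaled = logits / temperature          # temperature < 1 sharpens, > 1 flattens
    top = np.argsort(scaled)[-k:]          # indices of the k largest logits
    z = scaled[top] - scaled[top].max()
    probs = np.exp(z) / np.exp(z).sum()    # softmax restricted to the top k
    return int(rng.choice(top, p=probs))

rng = np.random.default_rng(0)
logits = np.array([0.1, 2.0, 0.3, 1.5, -1.0])
token = sample_top_k(logits, k=2, temperature=0.8, rng=rng)
# with k=2, only ids 1 and 3 can ever be drawn
```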